{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "The XLA compile API",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "metadata": {
        "colab_type": "text",
        "id": "f4TSNCvpENrW"
      },
      "cell_type": "markdown",
      "source": [
        "##### Copyright 2018 The TensorFlow Authors."
      ]
    },
    {
      "metadata": {
        "cellView": "form",
        "colab_type": "code",
        "id": "vamNSA0vEP-m",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "e1oSi4lHFt3z"
      },
      "cell_type": "markdown",
      "source": [
        "# The XLA compile API"
      ]
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "b7noD9NjFRL-"
      },
      "cell_type": "markdown",
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/xla/tutorials/xla_compile\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "v9YbsuLZaBXy"
      },
      "cell_type": "markdown",
      "source": [
        "\n",
        "\n",
        "Import TensorFlow and the XLA library. XLA contains `xla.compile()`, an experimental API that compiles part or all of a model with [XLA](https://www.tensorflow.org/extend/xla/)."
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "45kUPj5ZFrRa",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "from tensorflow.contrib.compiler import xla"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "GZVNiRmTDV-5"
      },
      "cell_type": "markdown",
      "source": [
        "Define some necessary constants and prepare the MNIST dataset."
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "f37TSEGvGX4_",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Size of each input image, 28 x 28 pixels\n",
        "IMAGE_SIZE = 28 * 28\n",
        "# Number of distinct number labels, [0..9]\n",
        "NUM_CLASSES = 10\n",
        "# Number of examples in each training batch (step)\n",
        "TRAIN_BATCH_SIZE = 100\n",
        "# Number of training steps to run\n",
        "TRAIN_STEPS = 1000"
      ],
      "execution_count": 0,
      "outputs": []
    },
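    {
      "metadata": {
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "As a rough sanity check of these constants, and assuming the standard MNIST training split of 60,000 images (a property of the dataset, not something this notebook sets), the number of passes over the training data works out to:\n",
        "\n",
        "$$\\frac{1000 \\times 100}{60000} \\approx 1.67$$\n",
        "\n",
        "Here 1000 is `TRAIN_STEPS` and 100 is `TRAIN_BATCH_SIZE`. The training dataset built in the next cell is repeated with `.repeat()`, so running past one full pass does not exhaust the iterator."
      ]
    },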
    {
      "metadata": {
        "colab_type": "code",
        "id": "TiVXchblG5hK",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Loads MNIST dataset.\n",
        "train, test = tf.keras.datasets.mnist.load_data()\n",
        "train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()\n",
        "test_ds = tf.data.Dataset.from_tensor_slices(test).batch(TRAIN_BATCH_SIZE)\n",
        "\n",
        "iterator = tf.data.Iterator.from_structure(train_ds.output_types, train_ds.output_shapes)\n",
        "images, labels = iterator.get_next()\n",
        "images = tf.reshape(images, [-1, IMAGE_SIZE])\n",
        "images, labels = tf.cast(images, tf.float32), tf.cast(labels, tf.int64)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "x_ZehpZP-SfS"
      },
      "cell_type": "markdown",
      "source": [
        "# Define the model constructing function\n",
        "\n",
        "Following code block contains a function that constructs a simple model with one dense layer, including both forward and backward propagation.\n",
        "\n",
        "When called, it returns two values. `y` is a `tf.Tensor` representing predicted probability of each target class, `train_step` is a `tf.Operation` that increments `global_step` and applies variable update."
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "ZbhJl_WvGa3g",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def build_mnist_model(x, y_):\n",
        "  y = tf.keras.layers.Dense(NUM_CLASSES).apply(x)\n",
        "\n",
        "  cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)\n",
        "  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)\n",
        "\n",
        "  return y, train_step"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "7Jh3lyQHDfM9"
      },
      "cell_type": "markdown",
      "source": [
        "# Enable XLA\n",
        "\n",
        "Use `xla.compile` with the `build_mnist_model` function to enable XLA. Following code block wraps the model with `xla.compile()`, which allows the target function with provided inputs to be executed by XLA."
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "kYpCXCdRHNuN",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "[y] = xla.compile(build_mnist_model, inputs=[images, labels])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "4giQh62IrZGF"
      },
      "cell_type": "markdown",
      "source": [
        "When compiling the graph, XLA replaces all the graph nodes constructed in the target function with a few XLA ops.\n",
        "\n",
        "xla.compile does not return any\n",
        "`tf.Operation` nodes that can be executed independently from the generated XLA ops. Instead, returned `tf.Operation` nodes from the target function are added as control dependencies of all returned `tf.Tensor` values. This triggers execution of the `tf.Operation` nodes when the returned tensors are evaluated.\n",
        "\n",
        "In pseudo-code, xla.compile's implementation looks as follows:\n",
        "\n",
        "---\n",
        "```\n",
        "# Ask Tensorflow to execute code in XLA-friendly manner\n",
        "\n",
        "y, train_step = build_mnist_model(images, labels)\n",
        "with tf.control_dependencies([train_step]):\n",
        "  y = tf.identity(y)\n",
        "\n",
        "# Ask Tensorflow to STOP executing code in XLA-friendly manner\n",
        "```\n",
        "---\n",
        "\n",
        "xla.compile() always returns a list of `tf.Tensor`'s (even if there is only one-element)."
      ]
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "TPGas4jjFLZl"
      },
      "cell_type": "markdown",
      "source": [
        "If you were to print the constructed graph now, you will see that it is not much different from a normal Tensorflow graph and you won't be able to find XLA ops mentioned before. This is because the actual compilation happens later when you try to execute the graph with `sess.run()`.  At that time, Tensorflow triggers a series of graph rewrite passes that actually generate XLA ops, which compiles and executes computation when all inputs are ready."
      ]
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "EZD1m_n1DxAF"
      },
      "cell_type": "markdown",
      "source": [
        "# Train and test the model"
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "qe28bAHNHUG2",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Creates session and initialize all variables.\n",
        "# xla.compile() doesn't work with Keras model.fit() API or TF eager mode yet.\n",
        "sess = tf.Session()\n",
        "sess.run(tf.global_variables_initializer())"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "qgsKmz3n2UiW"
      },
      "cell_type": "markdown",
      "source": [
        "Following code block trains model. Evaluating `y` also triggers its control dependency node `train_step`, which updates model variables."
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "_GxF6jTRHVuA",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "fbf299ca-02d5-4e95-f9fe-8f3c0432d132"
      },
      "cell_type": "code",
      "source": [
        "# Feeds training dataset\n",
        "sess.run(iterator.make_initializer(train_ds))\n",
        "\n",
        "# Runs TRAIN_STEPS steps\n",
        "for i in range(TRAIN_STEPS):\n",
        "  sess.run(y)\n",
        "\n",
        "print(\"Model trained for %s steps.\" % TRAIN_STEPS)"
      ],
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Model trained for 1000 steps.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "dHlQlRSRHXD1",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "9c3677a2-ec84-406f-9d2c-d722844f3093"
      },
      "cell_type": "code",
      "source": [
        "# Tests trained model\n",
        "\n",
        "# Feeds testing dataset\n",
        "sess.run(iterator.make_initializer(test_ds))\n",
        "\n",
        "# Calculates accuracy\n",
        "correct_prediction = tf.equal(tf.argmax(y, 1), labels)\n",
        "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
        "print(\"Prediction accuracy after training: %s\" % sess.run(accuracy))"
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Prediction accuracy after training: 0.91\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "ynJQIuzjHYOb",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Cleans up session\n",
        "sess.close()"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}